import math, re, hashlib
from typing import Union, Optional, Dict, Tuple, Callable 

import numba
import numpy as np
try:
    import sympy as sp
except Exception:
    sp = None

# Local imports from the package.  This module is part of the ``dae``
# package, in which ``graph_utils`` is a subpackage.  The relative imports
# below resolve whenever this module is imported as part of that package
# (from a source checkout or a pip install); they fail if this file is run
# directly as a script.
from .graph_utils import build_weighted_graph
from .graph_utils.kernels import (
    default_kernel,              # Q(r^2)
    DEFAULT_KERNEL_DEFAULT_PARAMS,
    make_builtin_kernel,
    builtin_kernel_names,
)
from .graph_utils.spectral import init_spectral

# Process-wide memo of compiled epoch kernels, filled by _get_epoch_kernel.
_EPOCH_KERNELS = {}  # cache: (opt_code, dim_code, has_clip) -> compiled kernel

# Memo of kernels compiled from expressions by _compile_force_function,
# keyed on (source text, free symbol names, base function name).
_FORCE_CACHE: Dict[Tuple[str, Tuple[str, ...], str], object] = {}
# Numba signature shared by every s-based force kernel: f(s, params) -> float32.
_FORCE_SIGNATURE = numba.float32(numba.float32, numba.float32[:])
# ------ s-based heavy-tail defaults for optimizer ------
# Q(s) = (1 + a s)^(-b)
# d/ds log Q(s) = - a b / (1 + a s)              # used for attraction
# - d/ds Q(s)   =   a b (1 + a s)^(-(b+1))       # used for repulsion


@numba.njit(_FORCE_SIGNATURE, fastmath=True, cache=True, nogil=True)
def _ht_Q_s(s, p):
    """Heavy-tail similarity Q(s) = (1 + a*s)^(-b), with p = [a, b]."""
    base = 1.0 + p[0] * s
    return math.pow(base, -p[1])

@numba.njit(_FORCE_SIGNATURE, fastmath=True, cache=True, nogil=True)
def _ht_dlogQ_s(s, p):
    """Attraction scalar d/ds log Q(s) = -a*b / (1 + a*s), with p = [a, b]."""
    ab = p[0] * p[1]
    return -ab / (1.0 + p[0] * s)

@numba.njit(_FORCE_SIGNATURE, fastmath=True, cache=True, nogil=True)
def _ht_negdQ_s(s, p):
    """Repulsion scalar -dQ/ds = a*b * (1 + a*s)^(-(b+1)), with p = [a, b]."""
    ab = p[0] * p[1]
    exponent = -(p[1] + 1.0)
    return ab * math.pow(1.0 + p[0] * s, exponent)



# --- Embedding metric: s and grad s ---
# Numba signature shared by every embedding-metric "sgrad" function:
#   sgrad(y_i, y_j, grad_out, params) -> s
# which returns the metric value s and writes ds/dy_i into grad_out.
_EM_SGRAD_SIG = numba.float32(
    numba.float32[:],  # y_i
    numba.float32[:],  # y_j
    numba.float32[:],  # grad_out
    numba.float32[:]   # params (e.g., [eps, p])
)

@numba.njit(_EM_SGRAD_SIG, fastmath=True, cache=True, nogil=True)
def em_euclid2_sgrad(y_i, y_j, grad_out, params):
    """
    Squared Euclidean distance s = sum_d (y_i[d] - y_j[d])^2.

    grad_out receives ds/dy_i = 2 * (y_i - y_j).  The returned s is floored
    at 1e-12.  params is unused.
    """
    sq = 0.0
    for idx in range(y_i.shape[0]):
        diff = y_i[idx] - y_j[idx]
        grad_out[idx] = 2.0 * diff
        sq += diff * diff
    if sq < 1e-12:
        sq = 1e-12
    return sq

@numba.njit(_EM_SGRAD_SIG, fastmath=True, cache=True, nogil=True)
def em_euclid_sgrad(y_i, y_j, grad_out, params):
    """
    Euclidean distance with offset: s = ||y_i - y_j|| + eps (params[0] = eps).

    grad_out receives (y_i - y_j) / s, an eps-regularized ds/dy_i.
    """
    eps = params[0]
    n = y_i.shape[0]
    sq = 0.0
    for idx in range(n):
        diff = y_i[idx] - y_j[idx]
        sq += diff * diff
    dist = math.sqrt(sq) + eps
    inv_dist = 1.0 / dist
    for idx in range(n):
        grad_out[idx] = (y_i[idx] - y_j[idx]) * inv_dist
    return dist

@numba.njit(_EM_SGRAD_SIG, fastmath=True, cache=True, nogil=True)
def em_l1_sgrad(y_i, y_j, grad_out, params):
    """
    Manhattan distance s = sum_d |y_i[d] - y_j[d]|.

    grad_out receives the sign of each difference (0 where it is exactly 0).
    params is unused.
    """
    total = 0.0
    for idx in range(y_i.shape[0]):
        diff = y_i[idx] - y_j[idx]
        total += abs(diff)
        if diff == 0.0:
            grad_out[idx] = 0.0
        elif diff > 0.0:
            grad_out[idx] = 1.0
        else:
            grad_out[idx] = -1.0
    return total

@numba.njit(_EM_SGRAD_SIG, fastmath=True, cache=True, nogil=True)
def em_lp_sgrad(y_i, y_j, grad_out, params):
    """
    Smoothed Minkowski distance s = (sum_d (|y_i[d] - y_j[d]| + eps)^p)^(1/p).

    params = [eps, p].  grad_out receives an eps-regularized ds/dy_i:
        diff * (|diff| + eps)^(p - 2) / (s + eps)^(p - 1).
    """
    eps, p = params[0], params[1]
    n = y_i.shape[0]
    powered = 0.0
    # First pass: stash raw differences in grad_out and accumulate p-th powers.
    for idx in range(n):
        diff = y_i[idx] - y_j[idx]
        grad_out[idx] = diff
        powered += math.pow(abs(diff) + eps, p)
    s = math.pow(powered, 1.0 / p)
    scale = math.pow(s + eps, p - 1.0)
    # Second pass: turn the stashed differences into gradient components.
    for idx in range(n):
        diff = grad_out[idx]
        grad_out[idx] = (diff * math.pow(abs(diff) + eps, p - 2.0)) / scale
    return s

@numba.njit(_EM_SGRAD_SIG, fastmath=True, cache=True, nogil=True)
def em_corr_sgrad(y_i, y_j, grad_out, params):
    """
    Correlation distance s = 1 - corr(y_i, y_j) and its gradient w.r.t. y_i.

    With u = y_i - mean(y_i) and v = y_j - mean(y_j), the code uses
        corr = <u, v> / ((|u| + eps) * (|v| + eps)).
    grad_out receives an eps-regularized ds/dy_i: the derivative of corr
    w.r.t. u (with |u| + eps in place of |u|), re-centred by subtracting its
    mean (the Jacobian of the centering step), then negated because
    s = 1 - corr.

    params[0] = eps.  Returns s.
    """
    # params[0] = eps
    eps = params[0]
    n = y_i.shape[0]
    # Per-vector means over the coordinates.
    mu_i = 0.0; mu_j = 0.0
    for d in range(n):
        mu_i += y_i[d]; mu_j += y_j[d]
    mu_i /= n; mu_j /= n

    # Inner product and squared norms of the centred vectors.
    dot = 0.0; ui2 = 0.0; vj2 = 0.0
    for d in range(n):
        u = y_i[d] - mu_i
        v = y_j[d] - mu_j
        dot += u * v
        ui2 += u * u
        vj2 += v * v

    # eps keeps the denominator away from zero for constant vectors.
    Bu = math.sqrt(ui2) + eps
    Bv = math.sqrt(vj2) + eps
    denom = Bu * Bv
    corr = dot / denom

    # d corr / d u (approximate: uses Bu = |u| + eps instead of |u|).
    sum_g = 0.0
    for d in range(n):
        u = y_i[d] - mu_i
        v = y_j[d] - mu_j
        gud = v / denom - corr * u / (Bu * Bu)
        grad_out[d] = gud
        sum_g += gud
    # Project through the centering (subtract mean) and negate for s = 1 - corr.
    mean_g = sum_g / n
    for d in range(n):
        grad_out[d] = -(grad_out[d] - mean_g)
    return 1.0 - corr

def _resolve_em_sgrad(name: Optional[str], custom_fn=None):
    """
    Choose the embedding-metric ``sgrad`` implementation used by the optimizer.

    Parameters
    ----------
    name : str or None
        Case-insensitive metric name; ``None`` means ``"l2sqr"``.  Recognized
        names: ``l2sqr``/``sqeuclidean``/``euclid2``, ``l2``/``euclidean``,
        ``l1``/``manhattan``, ``lp``, ``corr``/``correlation``.  Unrecognized
        names fall back to squared Euclidean and emit a ``RuntimeWarning``.
    custom_fn : callable, optional
        User ``sgrad(y_i, y_j, grad_out, params) -> s``.  It takes precedence
        over *name*.  It is returned as-is if it already has numba
        ``signatures``; otherwise it is compiled with ``_EM_SGRAD_SIG``.

    Returns
    -------
    A numba-compiled sgrad function.
    """
    if custom_fn is not None:
        if hasattr(custom_fn, "signatures"):
            return custom_fn
        try:
            return numba.njit(_EM_SGRAD_SIG, cache=True, fastmath=True, nogil=True)(custom_fn)
        except RuntimeError:
            # numba's on-disk cache needs a locatable source file (functions
            # created via exec/REPL have none); compile without disk caching.
            return numba.njit(_EM_SGRAD_SIG, cache=False, fastmath=True, nogil=True)(custom_fn)

    aliases = {
        "l2sqr": em_euclid2_sgrad,
        "sqeuclidean": em_euclid2_sgrad,
        "euclid2": em_euclid2_sgrad,
        "l2": em_euclid_sgrad,
        "euclidean": em_euclid_sgrad,
        "l1": em_l1_sgrad,
        "manhattan": em_l1_sgrad,
        "lp": em_lp_sgrad,
        "corr": em_corr_sgrad,
        "correlation": em_corr_sgrad,
    }
    key = (name or "l2sqr").lower()
    fn = aliases.get(key)
    if fn is None:
        import warnings

        # Previously this fell back silently, hiding typos such as "euclidian".
        warnings.warn(
            f"Unknown embedding metric {name!r}; falling back to 'l2sqr'.",
            RuntimeWarning,
            stacklevel=2,
        )
        return em_euclid2_sgrad
    return fn


def _resolve_force(obj, default_fn, free_syms: Tuple[str, ...], base_name: str):
    """
    Accepts: None | str (built-in name or expression) | sympy.Expr | python callable
    Returns: numba-compiled function with signature (float32, float32[:]) -> float32
    """
    if obj is None:
        return default_fn
    if isinstance(obj, (str, sp.Expr)):
        # If it's a built-in family *name*, we don't compile an expression here.
        # Built-ins are handled in _select_kernels_and_params.
        if isinstance(obj, str) and obj.lower() in builtin_kernel_names():
            # Caller will replace using make_builtin_kernel
            return default_fn
        # Otherwise treat as an expression and compile
        if sp is None:
            import sympy as sp  # import only if/when needed
        return _compile_force_function(obj, free_syms, base_name)
    # python callable: wrap with numba to ensure proper signature & speed
    return numba.njit(_FORCE_SIGNATURE, fastmath=True, cache=True, nogil=True)(obj)

def _compile_force_function(expr: Union[str, "sp.Expr"], free_syms: Tuple[str, ...], base_name: str = "fn"):
    """
    Compile a function f(s, params[:]) -> float32 from a SymPy expr or Python string.

    The expression may reference ``s`` (the embedding-metric value), ``math.*``
    and ``np.*``.  Each name in *free_syms* is replaced, as a whole word, by the
    parameter slot ``p[idx]`` matching its position in *free_syms*.  Results are
    memoized in ``_FORCE_CACHE`` keyed on (source text, free_syms, base_name).

    SECURITY NOTE: the expression text is passed to ``exec``.  Only pass trusted
    expressions, never data from an untrusted source.
    """
    # Plain strings must work even when sympy is not installed (sp is None).
    if sp is not None and isinstance(expr, sp.Expr):
        code = sp.printing.pycode(expr)
    else:
        code = str(expr)
    cache_key = (code, tuple(free_syms), base_name)
    cached = _FORCE_CACHE.get(cache_key)
    if cached is not None:
        return cached

    for idx, sym in enumerate(free_syms):
        code = re.sub(rf"\b{re.escape(sym)}\b", f"p[{idx}]", code)

    fn_name = f"{base_name}_{hashlib.sha1(code.encode()).hexdigest()[:12]}"
    # NOTE: argument is 's' to reinforce s-based contract
    src = f"def {fn_name}(s, p):\n    return {code}\n"

    ns = {}
    exec(compile(src, "<ctmc_kernel>", "exec"), {"math": math, "np": np}, ns)
    fn_py = ns[fn_name]
    # cache=False: numba's on-disk cache needs a real source file, and
    # "<ctmc_kernel>" has none (cache=True raises "no locator available").
    # In-process reuse is already provided by _FORCE_CACHE.
    fn_jit = numba.njit(_FORCE_SIGNATURE, fastmath=True, cache=False, nogil=True)(fn_py)
    _FORCE_CACHE[cache_key] = fn_jit
    return fn_jit

@numba.njit(inline='always', fastmath=True, cache=True, nogil=True)
def _sgd_apply_pair_vecclip(Y, i, j, grad_buf, g_scalar, lr, clip_norm):
    """
    Apply symmetric pair update (i +=, j -=) for SGD using a gradient vector
    stored in grad_buf as ∂s/∂y_i, scaled by g_scalar (e.g., d/ds log Q or row_scale*dQ/ds).
    Uses vector-norm clipping with threshold clip_norm (0 disables).

    Y[i] += lr * g_scalar * grad_buf and Y[j] -= lr * g_scalar * grad_buf,
    after optionally shrinking g_scalar so that ||g_scalar * grad_buf|| <= clip_norm.
    """
    dim = grad_buf.shape[0]
    if clip_norm > 0.0:
        # The norm pass is only needed when clipping is enabled; callers that
        # pre-clip pass 0.0 and skip it.
        g2 = 0.0
        for d in range(dim):
            v = g_scalar * grad_buf[d]
            g2 += v * v
        gn = math.sqrt(g2) + 1e-12
        if gn > clip_norm:
            g_scalar *= (clip_norm / gn)
    # apply
    for d in range(dim):
        v = g_scalar * grad_buf[d]
        Y[i, d] += lr * v
        Y[j, d] -= lr * v


@numba.njit(fastmath=True, cache=True, nogil=True)
def _recenter_and_rms(Y, target_rms):
    """
    Shift Y in place to zero mean per column and, if target_rms > 0, rescale
    so that the root-mean-square of the centred entries equals target_rms.
    """
    n_rows, n_cols = Y.shape
    centre = np.zeros(n_cols, dtype=np.float32)
    for c in range(n_cols):
        acc = 0.0
        for r in range(n_rows):
            acc += Y[r, c]
        centre[c] = acc / n_rows

    factor = 1.0
    if target_rms > 0.0:
        sq = 0.0
        for c in range(n_cols):
            m = centre[c]
            for r in range(n_rows):
                delta = Y[r, c] - m
                sq += delta * delta
        rms = math.sqrt(sq / (n_rows * n_cols)) + 1e-12
        if rms > 0.0:
            factor = target_rms / rms

    for c in range(n_cols):
        m = centre[c]
        for r in range(n_rows):
            Y[r, c] = (Y[r, c] - m) * factor


@numba.njit(inline='always', fastmath=True, cache=True, nogil=True)
def _linear_warmup(epoch, warm_epochs):
    """
    Ramp factor (epoch + 1) / warm_epochs, capped at 1.0 (epoch is 0-based).

    warm_epochs <= 0 disables the ramp and always yields 1.0.
    """
    if warm_epochs <= 0.0:
        return 1.0
    progress = float(epoch) + 1.0
    return 1.0 if progress >= warm_epochs else progress / warm_epochs



@numba.njit(inline="always", fastmath=True, cache=True, nogil=True)
def _clamp_scalar(g, clip):
    """Clamp g into [-clip, clip]; clip <= 0 disables clamping."""
    if not clip <= 0.0:
        if g < -clip:
            return -clip
        if g > clip:
            return clip
    return g

@numba.njit(inline="always", fastmath=True, cache=True, nogil=True)
def _rmsprop_fused(g, v_old, lr, rho, eps):
    """One RMSProp coordinate step; returns (update, new second-moment)."""
    second = rho * v_old + (1.0 - rho) * g * g
    step = lr * g / (math.sqrt(second) + eps)
    return step, second

@numba.njit(inline="always", fastmath=True, cache=True, nogil=True)
def _adam_fused_corr(g, m_old, v_old, lr, b1, b2, eps, corr1, corr2):
    """
    One Adam coordinate step using precomputed bias-correction factors
    corr1 = 1/(1 - b1^t) and corr2 = 1/(1 - b2^t).

    Returns (update, new first moment, new second moment).
    """
    first = b1 * m_old + (1.0 - b1) * g
    second = b2 * v_old + (1.0 - b2) * g * g
    first_hat = first * corr1
    second_hat = second * corr2
    step = lr * first_hat / (math.sqrt(second_hat) + eps)
    return step, first, second

@numba.njit(inline="always", fastmath=True, cache=True, nogil=True)
def _adafactor_fused(g, v_old, lr, rho, eps):
    """
    Per-coordinate "Adafactor" step; returns (update, new second-moment).

    The arithmetic is identical to _rmsprop_fused: no factored second-moment
    estimate is kept here.
    """
    second = rho * v_old + (1.0 - rho) * g * g
    step = lr * g / (math.sqrt(second) + eps)
    return step, second

@numba.njit(inline="always")
def fast_randint(k, epoch, neg_idx, n_pts, seed):
    """
    PCG-like integer RNG; fast and deterministic given inputs.

    Stateless hash returning an index in [0, n_pts): the inputs are XOR-mixed
    into a 64-bit state (epoch shifted left 16 bits, neg_idx 32 bits), advanced
    by one 64-bit LCG step using the PCG multiplier/increment constants, and
    reduced to 32 bits with a PCG-style xorshift + random-rotation output
    permutation.  The same (k, epoch, neg_idx, seed) always gives the same index.

    Note: the final ``%`` has a small modulo bias when n_pts is not a power of two.
    """
    # Mix the call coordinates into one 64-bit state.
    state = np.uint64(k) ^ (np.uint64(epoch) << 16) ^ (np.uint64(neg_idx) << 32) ^ np.uint64(seed)
    # One LCG step (uint64 arithmetic wraps modulo 2**64).
    state = state * np.uint64(6364136223846793005) + np.uint64(1442695040888963407)
    # Xorshift the high bits down to 32 bits...
    xorshifted = np.uint32(((state >> 18) ^ state) >> 27)
    # ...then rotate right by the amount in the top 5 bits.
    rot = np.int32(state >> 59)
    result = (xorshifted >> rot) | (xorshifted << ((-rot) & 31))
    return int(result % np.uint32(n_pts))

def _compute_intervals(weights, n_epochs):
    # Vectorized, no JIT
    w = weights.astype(np.float32, copy=False)
    w_max = float(w.max())
    if w_max <= 0.0:
        return np.full_like(w, -1.0, dtype=np.float32)
    out = np.full_like(w, -1.0, dtype=np.float32)
    mask = w > 0.0
    out[mask] = w_max / w[mask]  # same as n_epochs / (n_epochs * w / w_max)
    return out

# --- Helper to fit (a,b) for UMAP-like min-dist heavy tail ---
def _fit_ab_from_min_dist_heavytail(spread: float, min_dist: float, embed_metric_name: str) -> Tuple[np.float32, np.float32]:
    """
    Fit (a,b) for Q(s) = (1 + a*s)^(-b) so that Q(s(d)) ≈ target(d), where
        target(d) = 1 if d < min_dist else exp(-(d - min_dist)/spread).

    If embed_metric is 'l2sqr' we use s = d^2; otherwise s = d.
    """
    xv = np.linspace(0.0, float(spread) * 3.0, 300, dtype=np.float32)

    yv = np.ones_like(xv, dtype=np.float32)
    mask = xv >= float(min_dist)
    yv[mask] = np.exp(-(xv[mask] - float(min_dist)) / float(spread)).astype(np.float32)

    # Map to the optimizer’s s
    metric_key = (embed_metric_name or "").lower()
    if metric_key in ("l2sqr", "sqeuclidean", "euclid2"):
        s = (xv * xv)
    else:
        s = xv

    # Lightweight grid search (keeps dependencies minimal)
    a_grid = np.geomspace(0.1, 100.0, num=60).astype(np.float32)
    b_grid = np.linspace(0.1, 5.0, num=60).astype(np.float32)

    best_a, best_b, best_err = np.float32(1.0), np.float32(1.0), 1e30
    for a in a_grid:
        base = 1.0 + a * s
        for b in b_grid:
            pred = np.power(base, -b).astype(np.float32, copy=False)
            err = float(np.mean((pred - yv) ** 2))
            if err < best_err:
                best_err = err
                best_a, best_b = a, b
    return best_a, best_b

@numba.njit(inline='always')
def _igcd(a, b):
    """Greatest common divisor of |a| and |b| by Euclid's algorithm (nopython-safe)."""
    x = -a if a < 0 else a
    y = -b if b < 0 else b
    while y != 0:
        x, y = y, x % y
    return x

def _make_epoch_kernel(opt_code: int, dim_code: int, has_clip: int):
    """
    Build the numba-parallel training kernel for one optimizer configuration.

    The returned kernel closes over *opt_code* and *has_clip*, which numba
    captures as constants at compile time.  ``_get_epoch_kernel`` memoizes one
    kernel per (opt_code, dim_code, has_clip).

    Parameters
    ----------
    opt_code : int
        0 = SGD, 1 = RMSProp, 2 = Adam, 3 = "Adafactor" (see _adafactor_fused).
    dim_code : int
        Kept for interface compatibility; not used by the kernel.
    has_clip : int
        1 enables pair-vector norm clipping with threshold ``hp_arr[3]``.
    """
    OC = int(opt_code)   # 0 sgd, 1 rmsprop, 2 adam, 3 adafactor
    DC = int(dim_code)   # kept for interface compatibility (unused here)
    CL = int(has_clip)   # 0/1

    @numba.njit(parallel=True, fastmath=True, cache=True, nogil=True)
    def _epoch_kernel(
        Y, buf1, buf2,
        ei, ej, weights, row_scale,
        ui, us, uir, usr,
        params, f_attr, f_rep,
        hp_arr,
        lr0,
        n_epochs,
        block_size,
        rng_seed,
        # embedding metric
        em_sgrad_fn,     # ∂s/∂y_i provider
        em_params,       # float32[:] (e.g. [eps, p])
        # tracing
        trace_idx, trace_buf, snap_map
    ):
        """
        Run n_epochs of edge-sampled attraction and negative-sampled repulsion.

        Updates Y, buf1, buf2, us, usr and trace_buf in place; returns None.

        - Y: (n_pts, dim) embedding.  buf1/buf2: optimizer state indexed like Y.
        - ei/ej: edge endpoints.  ``weights`` is not read here.
        - ui/us: attraction interval and next-due epoch per edge;
          uir/usr: the same for negative sampling.
        - f_attr(s, params), f_rep(s, params): attraction / repulsion scalars.
          row_scale[i] multiplies repulsion.
        - hp_arr[0..3]: Adam (beta1, beta2, eps, clip); RMSProp/Adafactor
          (rho, eps, -, clip); SGD (repulsion warm-up epochs, recenter period,
          target RMS, clip).  hp_arr[4..7] (optional): early-exaggeration factor,
          exaggeration epochs, hard-negative candidates, shuffle flag.
        - lr0: initial learning rate, cosine-decayed over the epochs.
        - snap_map[epoch] >= 0 selects the trace_buf slot that receives
          Y[trace_idx] at the end of that epoch.

        Edge blocks run in parallel without synchronization, so points shared
        between blocks may be updated concurrently.
        """
        OC_loc = int(opt_code)
        n_edges = int(ei.shape[0])
        n_pts, dim = Y.shape

        # Base hyperparams (optimizer-specific; positions 0..3)
        hp1, hp2, hp3, hp4 = hp_arr[0], hp_arr[1], hp_arr[2], hp_arr[3]

        # --- extras (backward compatible if missing); constant across epochs ---
        n_hp = hp_arr.shape[0]
        EE      = hp_arr[4] if n_hp > 4 else 1.0          # early exaggeration
        EE_ep   = hp_arr[5] if n_hp > 5 else 0.0          # exaggeration epochs
        K_neg   = int(hp_arr[6]) if n_hp > 6 else 1       # hard negative candidates
        do_shuf = (hp_arr[7] > 0.5) if n_hp > 7 else False  # shuffle edges

        # Pair-vector clip threshold for attraction and repulsion (0 disables).
        clip_norm = hp4 if CL == 1 else 0.0
        bs = int(block_size)

        for epoch in range(n_epochs):
            t = epoch + 1

            # Adam bias corrections
            if OC_loc == 2:
                corr1 = 1.0 / (1.0 - hp1 ** float(t))
                corr2 = 1.0 / (1.0 - hp2 ** float(t))
            else:
                corr1 = 1.0
                corr2 = 1.0

            # Cosine LR decay
            lr = lr0 * 0.5 * (1.0 + math.cos(math.pi * epoch / max(1, n_epochs - 1)))

            # Exaggeration schedule (linear decay)
            if EE_ep > 0.0:
                efrac = float(epoch) / EE_ep
                exag = 1.0 + (EE - 1.0) * (1.0 - (efrac if efrac < 1.0 else 1.0))
            else:
                exag = 1.0

            # --- edge ordering via stride permutation: k -> (offset + k*stride) mod n ---
            if do_shuf and n_edges > 1:
                ne = n_edges
                st = 2 * (epoch + 1) + 1  # odd
                if st >= ne:
                    st = (st % (ne - 1)) + 1
                    if (st & 1) == 0:
                        st += 1
                        if st >= ne:
                            st = 1
                attempts = 0
                # The stride must be coprime with ne for the map to be a permutation.
                while _igcd(st, ne) != 1 and attempts < 16:
                    st += 2
                    if st >= ne:
                        st = 1
                    attempts += 1
                stride = int(st if attempts < 16 else 1)
                offset = int((epoch * 2654435761) % ne)
            else:
                stride = 1
                offset = 0

            n_blocks = (n_edges + bs - 1) // bs

            for b in numba.prange(n_blocks):
                start = b * bs
                end   = min(start + bs, n_edges)

                # Scratch buffers private to this block (reused for every edge
                # instead of allocating per edge / per negative).
                grad_buf = np.empty(dim, dtype=np.float32)
                cand_buf = np.empty(dim, dtype=np.float32)

                for k_raw in range(start, end):
                    # apply permutation (ensure k is int)
                    k = int((offset + (k_raw * stride) % n_edges) % n_edges)

                    if us[k] > epoch:
                        continue

                    # ---- Attraction ----
                    i = int(ei[k])
                    j = int(ej[k])

                    s = em_sgrad_fn(Y[i, :], Y[j, :], grad_buf, em_params)
                    ga = f_attr(s, params) * exag

                    # pair-vector norm clipping for ALL optimizers
                    if clip_norm > 0.0:
                        g2 = 0.0
                        for d in range(dim):
                            v = ga * grad_buf[d]
                            g2 += v * v
                        gn = math.sqrt(g2) + 1e-12
                        if gn > clip_norm:
                            ga *= (clip_norm / gn)

                    if OC_loc == 0:  # SGD
                        _sgd_apply_pair_vecclip(Y, i, j, grad_buf, ga, lr, 0.0)  # already clipped
                    elif OC_loc == 1:  # RMSProp
                        for d in range(dim):
                            g = ga * grad_buf[d]
                            ui_, buf1[i, d] = _rmsprop_fused( g,  buf1[i, d], lr, hp1, hp2)
                            uj_, buf1[j, d] = _rmsprop_fused(-g,  buf1[j, d], lr, hp1, hp2)
                            Y[i, d] += ui_; Y[j, d] += uj_
                    elif OC_loc == 2:  # Adam
                        for d in range(dim):
                            g = ga * grad_buf[d]
                            ui_, buf1[i, d], buf2[i, d] = _adam_fused_corr( g,  buf1[i, d], buf2[i, d], lr, hp1, hp2, hp3, corr1, corr2)
                            uj_, buf1[j, d], buf2[j, d] = _adam_fused_corr(-g,  buf1[j, d], buf2[j, d], lr, hp1, hp2, hp3, corr1, corr2)
                            Y[i, d] += ui_; Y[j, d] += uj_
                    else:  # Adafactor
                        for d in range(dim):
                            g = ga * grad_buf[d]
                            ui_, buf1[i, d] = _adafactor_fused( g,  buf1[i, d], lr, hp1, hp2)
                            uj_, buf1[j, d] = _adafactor_fused(-g,  buf1[j, d], lr, hp1, hp2)
                            Y[i, d] += ui_; Y[j, d] += uj_

                    us[k] += ui[k]

                    # ---- Repulsion (negative sampling) ----
                    overdue = epoch - usr[k]
                    n_neg   = int(overdue / uir[k])
                    if n_neg > 0:
                        for neg_idx in range(n_neg):
                            # hard-negative mining: keep the closest of K candidates
                            best_l = -1
                            best_s = 1e30

                            if K_neg <= 1:
                                l = fast_randint(k, epoch, neg_idx, n_pts, rng_seed)
                                if l == i or l == j:
                                    continue
                                best_l = l
                            else:
                                for cand in range(K_neg):
                                    l_c = fast_randint(k, epoch, neg_idx + cand * 17, n_pts, rng_seed)
                                    if l_c == i or l_c == j:
                                        continue
                                    s_c = em_sgrad_fn(Y[i, :], Y[int(l_c), :], cand_buf, em_params)
                                    if s_c < best_s:
                                        best_s = s_c
                                        best_l = l_c

                            if best_l < 0:
                                continue

                            bl = int(best_l)

                            s_r = em_sgrad_fn(Y[i, :], Y[bl, :], grad_buf, em_params)
                            gr  = f_rep(s_r, params)
                            g_scalar = row_scale[i] * gr

                            # pair-vector clip repulsion too
                            if clip_norm > 0.0:
                                g2 = 0.0
                                for d in range(dim):
                                    v = g_scalar * grad_buf[d]
                                    g2 += v * v
                                gn = math.sqrt(g2) + 1e-12
                                if gn > clip_norm:
                                    g_scalar *= (clip_norm / gn)

                            if OC_loc == 0:  # SGD
                                rep_ramp = _linear_warmup(epoch, hp1)  # repulsion warm-up
                                _sgd_apply_pair_vecclip(Y, i, bl, grad_buf, g_scalar * rep_ramp, lr, 0.0)
                            elif OC_loc == 1:  # RMSProp
                                for d in range(dim):
                                    g = g_scalar * grad_buf[d]
                                    ui_, buf1[i,  d] = _rmsprop_fused( g,  buf1[i,  d], lr, hp1, hp2)
                                    ul_, buf1[bl, d] = _rmsprop_fused(-g,  buf1[bl, d], lr, hp1, hp2)
                                    Y[i,  d] += ui_; Y[bl, d] += ul_
                            elif OC_loc == 2:  # Adam
                                for d in range(dim):
                                    g = g_scalar * grad_buf[d]
                                    ui_, buf1[i,  d], buf2[i,  d]  = _adam_fused_corr( g,  buf1[i,  d],  buf2[i,  d],  lr, hp1, hp2, hp3, corr1, corr2)
                                    ul_, buf1[bl, d], buf2[bl, d]  = _adam_fused_corr(-g,  buf1[bl, d],  buf2[bl, d],  lr, hp1, hp2, hp3, corr1, corr2)
                                    Y[i,  d] += ui_; Y[bl, d] += ul_
                            else:  # Adafactor
                                for d in range(dim):
                                    g = g_scalar * grad_buf[d]
                                    ui_, buf1[i,  d] = _adafactor_fused( g,  buf1[i,  d], lr, hp1, hp2)
                                    ul_, buf1[bl, d] = _adafactor_fused(-g,  buf1[bl, d], lr, hp1, hp2)
                                    Y[i,  d] += ui_; Y[bl, d] += ul_

                        usr[k] += n_neg * uir[k]

            # Tracing: once per epoch, after ALL parallel blocks have finished.
            # (Previously this ran inside the prange body, so every block
            # rewrote the same snapshot concurrently with mid-epoch positions.)
            snap_id = snap_map[epoch]
            if snap_id >= 0:
                n_trace = trace_idx.shape[0]
                for t2 in range(n_trace):
                    ii = int(trace_idx[t2])
                    for d in range(dim):
                        trace_buf[snap_id, t2, d] = Y[ii, d]

            # Periodic recenter/RMS (SGD only)
            if OC_loc == 0 and hp2 > 0.0:
                if ((epoch + 1) % int(hp2)) == 0:
                    _recenter_and_rms(Y, target_rms=hp3)
        return

    return _epoch_kernel


def _get_epoch_kernel(opt_code: int, dim: int, clip_enabled: bool):
    """
    Return the memoized epoch kernel for this configuration, compiling it on first use.

    dim 2 and 3 get their own cache keys; every other dimensionality shares key 0.
    """
    key = (opt_code, {2: 2, 3: 3}.get(dim, 0), int(bool(clip_enabled)))
    kernel = _EPOCH_KERNELS.get(key)
    if kernel is None:
        kernel = _make_epoch_kernel(*key)
        _EPOCH_KERNELS[key] = kernel
    return kernel


def _metric_to_code(metric: str) -> int:
    """
    Map your graph metric name to the mid_shortcuts metric_code.
      0 = squared Euclidean, 1 = cosine, 2 = L1, 3 = correlation
    """
    m = (metric or "euclidean").lower()
    if m in ("cosine",):
        return 1
    if m in ("manhattan", "l1"):
        return 2
    if m in ("correlation", "corr"):
        return 3
    # default: squared Euclidean
    return 0

class CTMCEmbedding:
    _OPTIMIZER_CODES = {"sgd": 0, "rmsprop": 1, "adam": 2, "adafactor": 3}

    def __init__(
        self,
        n_neighbors: int = 15,
        n_components: int = 2,
        attr_kernel: Union[str, "sp.Expr", None] = None,
        rep_kernel: Union[str, "sp.Expr", None] = None,
        graph_kernel: Union[str, "sp.Expr", None] = None,
        kernel_params: Optional[Dict[str, float]] = None,
        optimizer: str = "sgd",
        negative_sample_rate: float = 5.0,
        gradient_clip: Optional[float] = None,
        random_state: Optional[int] = None,
        verbose: bool = True,
        repulsion: float = 500.0,
        # --- New/graph/init params ---
        graph_backend: str = "hnswlib",
        symmetrize: str = "mean",
        init_mode: str = "standard",
        component_strategy: str = "pack",
        spectral_backend: str = "amg",
        spectral_tol: float = 1e-5,
        block_size: int = 256,
        num_threads: Optional[int] = None,
        deterministic: bool = False,
        trace_indices: Optional[np.ndarray] = None,
        trace_every: int = 0,
        metric: str = "euclidean",
        metric_params: Optional[dict] = None,
        graph_kernels: Optional[Dict[str, callable]] = None,
        embed_metric: str = "l2sqr",
        min_dist: Optional[float] = .2,
        spread: float = 1.0,
        rep_warmup_epochs: int = 500,  # 0 disables
        center_every: int = 1,       # 0 disables
        target_rms: float = 0.0,     # 0 => just recenter
    ):
        if optimizer not in self._OPTIMIZER_CODES:
            raise ValueError(f"optimizer must be one of {list(self._OPTIMIZER_CODES)}")
        if n_neighbors <= 0 or n_components <= 0:
            raise ValueError("n_neighbors and n_components must be positive")
        self.rep_warmup_epochs = int(rep_warmup_epochs)
        self.center_every = int(center_every)
        self.target_rms = float(target_rms)
        self.n_neighbors = int(n_neighbors)
        self.n_components = int(n_components)
        self.optimizer = optimizer
        self.negative_sample_rate = float(negative_sample_rate)
        self.gradient_clip = gradient_clip
        self.random_state = random_state if random_state is not None else 42
        self.verbose = bool(verbose)
        self.repulsion = float(repulsion)
        # Graph/init pipeline params
        self.graph_backend = graph_backend
        self.symmetrize = symmetrize

        self.init_mode = init_mode
        self.component_strategy = component_strategy
        self.spectral_backend = spectral_backend
        self.spectral_tol = spectral_tol
        self.block_size = int(block_size)
        self.num_threads = num_threads
        self.deterministic = bool(deterministic)
        self.trace_indices = None if trace_indices is None else np.asarray(trace_indices, dtype=np.int32)
        self.trace_every = int(trace_every)

        # --- kernel specs (allow None -> use defaults) ---
        self.attr_kernel  = attr_kernel
        self.rep_kernel   = rep_kernel
        self.graph_kernel = graph_kernel
        self.kernel_params = dict(kernel_params) if kernel_params is not None else {"a": 1.0, "b": 1.0}
        # ensure both params exist
        if "a" not in self.kernel_params: self.kernel_params["a"] = 1.0
        if "b" not in self.kernel_params: self.kernel_params["b"] = 1.0

        
        self.trace_ = None
        self.trace_epochs_ = None
        self.embedding_ = None
        self.n_samples_ = None
        self._is_fitted = False

        self.metric = metric
        self.metric_params = metric_params if metric_params is not None else {}
        self.graph_kernels = graph_kernels or {}

        # Embedding metric (optimizer), independent of graph metric
        self.embed_metric = embed_metric   # options: 'l2sqr','l2','l1','lp','correlation'
        self.embed_p = 2.0
        self.embed_eps = 1e-12
        self.embed_metric_fn = None   # user may set a custom sgrad function

        # --- UMAP-like (a,b) fitting knobs (optional; no effect if None) ---
        self.min_dist = min_dist
        self.spread = float(spread)


    def _setup_optimizer_buffers(self, n_samples, n_components):
        clip_val = 0.0 if self.gradient_clip is None else float(self.gradient_clip)

        # NEW: extras (with safe defaults if attrs don't exist)
        ee     = float(getattr(self, "early_exaggeration", 1.0))  # 1.0 = disabled
        ee_ep  = float(getattr(self, "exaggeration_epochs", 0))   # 0   = disabled
        neg_c  = float(int(getattr(self, "neg_candidates", 1)))   # 1   = vanilla negatives
        shuf   = 1.0 if bool(getattr(self, "shuffle_edges", True)) else 0.0

        extras = (ee, ee_ep, neg_c, shuf)

        if self.optimizer == "adam":
            buf1 = np.zeros((n_samples, n_components), np.float32)
            buf2 = np.zeros((n_samples, n_components), np.float32)
            hps = (0.9, 0.98, 1e-8, clip_val) + extras
            return (buf1, buf2, hps)

        elif self.optimizer == "rmsprop":
            buf = np.zeros((n_samples, n_components), np.float32)
            hps = (0.95, 1e-6, 0.0, clip_val) + extras
            return (buf, buf, hps)

        elif self.optimizer == "adafactor":
            buf = np.zeros((n_samples, n_components), np.float32)
            hps = (0.95, 1e-6, 0.0, clip_val) + extras
            return (buf, buf, hps)

        else:  # sgd
            # existing: hp1=rep_warmup_epochs, hp2=center_every, hp3=target_rms, hp4=clip_norm
            rep_warm = float(getattr(self, "rep_warmup_epochs", 0) or 0)
            center_ev = float(getattr(self, "center_every", 0) or 0)
            target_r = float(getattr(self, "target_rms", 0.0) or 0.0)
            dummy = np.zeros((n_samples, n_components), np.float32)
            hps = (rep_warm, center_ev, target_r, clip_val) + extras
            return (dummy, dummy, hps)
                
    def _select_kernels_and_params(self):
        """
        Decide kernels for graph build and optimizer.

        Returns ``(graph_fn, attr_fn, rep_fn, params)`` where

        - graph_fn: Q used by the kNN graph builder (default: ``default_kernel``).
        - attr_fn:  f_attr(s) = d/ds log Q(s) (default: ``_ht_dlogQ_s``).
        - rep_fn:   f_rep(s)  = -dQ/ds (default: ``_ht_negdQ_s``).
        - params:   float32 ``[a, b]`` from ``kernel_params``, or the built-in
                    family's own params when a family is selected.

        If any slot (graph, attr, rep, checked in that order) names a built-in
        family, ``make_builtin_kernel`` supplies all three kernels.  Slots that
        are neither None nor that family name override individually.
        """
        if hasattr(self, "kernel_params"):
            a = float(self.kernel_params.get("a", 1.0))
            b = float(self.kernel_params.get("b", 1.0))
        else:
            a = b = 1.0
        params_arr = np.array([a, b], dtype=np.float32)
        free_syms = ("a", "b")

        slots = (self.graph_kernel, self.attr_kernel, self.rep_kernel)
        family = next(
            (
                spec.lower()
                for spec in slots
                if isinstance(spec, str) and spec.lower() in builtin_kernel_names()
            ),
            None,
        )

        if family is None:
            # Generic path (defaults / callables / expressions).
            return (
                _resolve_force(self.graph_kernel, default_kernel, free_syms, "Q"),
                _resolve_force(self.attr_kernel, _ht_dlogQ_s, free_syms, "dlogQ"),
                _resolve_force(self.rep_kernel, _ht_negdQ_s, free_syms, "negdQ"),
                params_arr,
            )

        Q, dlogQ, negdQ, fam_params = make_builtin_kernel(family, self.kernel_params)

        def _uses_family(spec):
            # A slot follows the family when unset or when it names the family itself.
            return spec is None or (isinstance(spec, str) and spec.lower() == family)

        if not _uses_family(self.graph_kernel):
            Q = _resolve_force(self.graph_kernel, Q, free_syms, "Q")
        if not _uses_family(self.attr_kernel):
            dlogQ = _resolve_force(self.attr_kernel, dlogQ, free_syms, "dlogQ")
        if not _uses_family(self.rep_kernel):
            negdQ = _resolve_force(self.rep_kernel, negdQ, free_syms, "negdQ")
        return Q, dlogQ, negdQ, fam_params

    
    def fit(self, X, y=None):
        """Build the graph, initial embedding and optimizer state.

        No optimisation epochs are run here; call ``transform`` (or use
        ``fit_transform``) to run them.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Input data, converted to float32.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self
            The fitted estimator.

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or has fewer than ``n_neighbors + 1`` rows.
        """
        data = np.asarray(X, dtype=np.float32)
        if data.ndim != 2:
            raise ValueError("X must be 2-D")
        n_rows = data.shape[0]
        if n_rows < self.n_neighbors + 1:
            raise ValueError("n_samples must be >= n_neighbors + 1")

        # Seed NumPy's global RNG (affects any later np.random use).
        seed = self.random_state
        if seed is not None:
            np.random.seed(int(seed))

        self.n_samples_ = n_rows

        prepared = self._prepare_graph_and_forces(data)
        ei, ej, weights, Y, params, f_attr, f_rep, row_scale = prepared
        n_points = Y.shape[0]
        n_dims = Y.shape[1]
        buf1, buf2, hps = self._setup_optimizer_buffers(n_points, n_dims)
        opt_code = self._OPTIMIZER_CODES[self.optimizer]

        self.embedding_ = Y
        # Layout consumed by ``transform``:
        # (ei, ej, weights, params, f_attr, f_rep, buf1, buf2, hp_arr, opt_code, row_scale)
        self._graph_data = (
            ei,
            ej,
            weights,
            params,
            f_attr,
            f_rep,
            buf1,
            buf2,
            np.array(hps, dtype=np.float32),
            opt_code,
            row_scale,
        )
        self._is_fitted = True
        return self
    
    def fit_transform_from_connectivities(self, X, connectivities, *, n_epochs=500, learning_rate=0.01):
        """Train CTMC on a precomputed adjacency and return the embedding.

        Use this instead of ``fit``/``transform`` when a weighted neighbour
        graph already exists, e.g. Scanpy/UMAP fuzzy connectivities
        (``adata.obsp['connectivities']``).

        Parameters
        ----------
        X : array-like of shape (n_samples, d_in)
            Used only for spectral initialisation and component packing.
        connectivities : scipy sparse matrix/array or array-like, shape (n_samples, n_samples)
            Edge weights; stored entry ``(i, j)`` becomes edge ``i -> j``.
            Explicitly stored zeros are ignored.
        n_epochs : int, default 500
        learning_rate : float, default 0.01
            Forwarded to ``transform``.

        Returns
        -------
        ndarray
            The result of ``transform`` (a copy of the final embedding).

        Raises
        ------
        ValueError
            If ``X`` is not 2-D, ``connectivities`` is not square, or its
            size does not match the number of rows of ``X``.
        ImportError
            If scipy is not installed.
        """
        # --- validate X ---
        X = np.asarray(X, dtype=np.float32)
        if X.ndim != 2:
            raise ValueError("X must be 2-D")

        try:
            # Lazy import: scipy is only required for this entry point.
            import scipy.sparse as sp_sparse
        except ImportError as e:
            raise ImportError("scipy is required for fit_transform_from_connectivities") from e

        # Accept any scipy sparse format as well as dense array-likes.
        if sp_sparse.issparse(connectivities):
            C = connectivities.tocsr()
        else:
            C = sp_sparse.csr_matrix(np.asarray(connectivities, dtype=np.float32))
        if C.shape[0] != C.shape[1]:
            raise ValueError(f"connectivities must be square, got shape {C.shape}")
        n = C.shape[0]
        if n != X.shape[0]:
            raise ValueError(f"connectivities shape {C.shape} mismatch with X {X.shape[0]} samples")

        # Same RNG handling and bookkeeping as ``fit``.
        if self.random_state is not None:
            np.random.seed(int(self.random_state))
        self.n_samples_ = n

        # --- extract edges and weights ---
        # Rows, columns and values come from ONE COO view with ONE mask so
        # they stay aligned: ``C.nonzero()`` drops explicitly stored zeros
        # while ``C.data`` keeps them, so pairing those two could misalign
        # edges and weights.  CSR -> COO preserves row-major order.
        coo = C.tocoo()
        keep = coo.data != 0
        ei = np.ascontiguousarray(coo.row[keep], dtype=np.int32)
        ej = np.ascontiguousarray(coo.col[keep], dtype=np.int32)
        P = np.ascontiguousarray(coo.data[keep], dtype=np.float32)

        # Same normalisation as the regular path: scale weights to max 1.
        weights = P / (float(P.max()) if P.size > 0 else 1.0)

        # Repulsion calibration (per-row scale, read-only for the kernel).
        row_sum = np.bincount(ei, weights=weights, minlength=n).astype(np.float32)
        nsr = float(self.negative_sample_rate)
        alpha = float(self.repulsion)
        row_scale = ((n - 1) / n) / (nsr * (row_sum + 1e-12))
        row_scale *= alpha
        row_scale.flags.writeable = False

        # Optimizer kernels + params; the graph kernel is irrelevant here
        # because the graph is supplied directly.
        _graph_fn, f_attr, f_rep, params = self._select_kernels_and_params()

        # Spectral init on the same graph (P, ei, ej)
        Y = init_spectral(
            X, P, ei, ej, n=n,
            init_mode=self.init_mode,
            component_strategy=self.component_strategy,
            backend=self.spectral_backend,
            tol=self.spectral_tol
        ).astype(np.float32, copy=False)
        Y = np.ascontiguousarray(Y[:, : self.n_components])

        # Optional (a,b) fitting from min_dist/spread; no effect if min_dist is None
        params_opt = params.copy()
        if self.min_dist is not None:
            a_fit, b_fit = _fit_ab_from_min_dist_heavytail(self.spread, float(self.min_dist), self.embed_metric)
            params_opt[0] = a_fit
            params_opt[1] = b_fit
            if self.verbose:
                print(f"Fitted (a,b) ≈ ({float(a_fit):.4g}, {float(b_fit):.4g}) from min_dist={self.min_dist}, spread={self.spread}")

        # Buffers & state
        buf1, buf2, hps = self._setup_optimizer_buffers(Y.shape[0], Y.shape[1])
        opt_code = self._OPTIMIZER_CODES[self.optimizer]

        self.embedding_ = Y
        self._graph_data = (
            ei, ej, weights, params_opt, f_attr, f_rep,
            buf1, buf2, np.array(hps, dtype=np.float32), opt_code, row_scale
        )
        self._is_fitted = True

        # Run epochs
        return self.transform(X=None, n_epochs=n_epochs, learning_rate=learning_rate)

    def fit_transform(self, X, y=None, **kwargs):
        """Fit on ``X`` and immediately run the optimisation epochs.

        Extra keyword arguments (e.g. ``n_epochs``, ``learning_rate``) are
        forwarded to ``transform``, whose result is returned.
        """
        fitted = self.fit(X, y)
        return fitted.transform(X, **kwargs)

    def set_params(self, **params):
        """Assign existing estimator attributes from keyword arguments.

        Attributes are set in argument order.  A name the instance does not
        already have raises ``ValueError``; assignments made before it are
        kept.  Returns ``self`` so calls can be chained.
        """
        for name, value in params.items():
            if hasattr(self, name):
                setattr(self, name, value)
                continue
            raise ValueError(f"Invalid parameter {name}")
        return self


    def transform(self, X=None, *, n_epochs=500, learning_rate=0.01):
        """Run ``n_epochs`` optimisation epochs on the fitted embedding.

        ``X`` is accepted for scikit-learn API compatibility and ignored: the
        graph and optimizer state stored by ``fit`` (or
        ``fit_transform_from_connectivities``) are used instead.  The epoch
        kernel is handed ``self.embedding_`` itself (presumably updated in
        place), so repeated calls continue from the current state.

        Parameters
        ----------
        X : ignored
        n_epochs : int, default 500
            Number of epochs; ``0`` skips the optimisation kernel.
        learning_rate : float, default 0.01
            Passed to the epoch kernel as float32.

        Returns
        -------
        ndarray
            A copy of ``self.embedding_`` after the epochs have run.

        Raises
        ------
        ValueError
            If the estimator has not been fitted.

        Notes
        -----
        Sets ``graph_``, ``trace_`` and ``trace_epochs_`` (``None`` when
        tracing is disabled).  Calls ``numba.set_num_threads`` when
        ``num_threads`` is set or ``deterministic`` is true.
        """
        if not getattr(self, "_is_fitted", False):
            raise ValueError("CTMCEmbedding not fitted yet.")

        (ei, ej, weights, params, f_attr, f_rep,
         buf1, buf2, hp_arr, opt_code, row_scale) = self._graph_data
        n_epochs = int(n_epochs)

        # --- scheduler ---
        ui = _compute_intervals(weights, n_epochs)
        uir = ui / float(self.negative_sample_rate)
        # us/usr are working copies (presumably advanced by the kernel);
        # ui/uir are frozen below so the kernel can only read them.
        us = ui.copy()
        # NOTE(review): the negative-sample schedule starts from ``ui`` rather
        # than ``uir``; confirm against the epoch kernel that this is intended.
        usr = ui.copy()
        for arr in (ui, uir):
            arr.setflags(write=False)

        # --- tracing setup ---
        dim = self.embedding_.shape[1]
        trace_indices = self.trace_indices
        if trace_indices is not None:
            # Accept any integer sequence, not only ndarrays.
            trace_indices = np.asarray(trace_indices)
        snap_map = np.full(n_epochs, -1, dtype=np.int32)
        if self.trace_every > 0 and trace_indices is not None and trace_indices.size > 0:
            snap_epochs = np.arange(0, n_epochs, int(self.trace_every), dtype=np.int32)
            # snap_map[epoch] -> snapshot slot, or -1 when no snapshot is taken.
            snap_map[snap_epochs] = np.arange(snap_epochs.shape[0], dtype=np.int32)
            trace_buf = np.empty((snap_epochs.shape[0], trace_indices.shape[0], dim), dtype=np.float32)
            trace_idx = np.ascontiguousarray(trace_indices, dtype=np.int32)
            self.trace_epochs_ = snap_epochs
        else:
            trace_buf = np.empty((0, 0, 0), dtype=np.float32)
            trace_idx = np.zeros((0,), dtype=np.int32)
            # Drop epochs left over from an earlier traced run.
            self.trace_epochs_ = None

        # --- threads / determinism ---
        if getattr(self, "num_threads", None) is not None:
            numba.set_num_threads(int(self.num_threads))
        elif self.deterministic:
            numba.set_num_threads(1)

        if self.verbose:
            print(f"Running {int(n_epochs)} epochs with '{self.optimizer}'…")

        if n_epochs > 0:
            kernel = _get_epoch_kernel(
                self._OPTIMIZER_CODES[self.optimizer],
                dim=dim,
                clip_enabled=(self.gradient_clip is not None and float(self.gradient_clip) > 0.0),
            )
            # Resolve sgrad function (+ parameters) for the embedding metric
            em_sgrad_fn = _resolve_em_sgrad(getattr(self, "embed_metric", "l2sqr"),
                                            custom_fn=getattr(self, "embed_metric_fn", None))
            em_params = np.array([float(self.embed_eps), float(self.embed_p)], dtype=np.float32)

            # ``fit`` accepts random_state=None; draw a seed from NumPy's
            # global RNG then instead of failing on int(None).
            if self.random_state is None:
                seed = int(np.random.randint(0, np.iinfo(np.int32).max))
            else:
                seed = int(self.random_state)

            kernel(
                self.embedding_, buf1, buf2,
                ei, ej, weights, row_scale,
                ui, us, uir, usr,
                params, f_attr, f_rep,
                np.array(hp_arr, dtype=np.float32),
                np.float32(learning_rate),
                n_epochs,
                int(self.block_size),
                seed,
                # embedding metric function & params
                em_sgrad_fn, em_params,
                trace_idx, trace_buf, snap_map
            )

        self.graph_ = (ei, ej, weights)  # Store graph
        self.trace_ = trace_buf if trace_buf.size else None
        return self.embedding_.copy()
    
    def _prepare_graph_and_forces(self, X):
        """Build the graph, spectral init and force setup used by ``fit``.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
            Input data (``fit`` passes a float32 2-D array).

        Returns
        -------
        tuple
            ``(ei, ej, weights, Y, params, f_attr, f_rep, row_scale)``:
            int32 edge endpoints ``ei``/``ej``; edge ``weights`` divided by
            their maximum; ``Y`` the C-contiguous float32 initial embedding
            truncated to ``n_components`` columns; ``params`` a copy of the
            kernel parameters with ``(a, b)`` replaced when ``min_dist`` is
            set; the optimizer force functions ``f_attr``/``f_rep``; and the
            read-only per-row repulsion scale ``row_scale``.
        """
        if self.verbose:
            print("Building graph…")
        graph_fn, f_attr, f_rep, params = self._select_kernels_and_params()

        # Optional: override only the graph kernel for the current metric.
        if self.graph_kernels and (self.metric in self.graph_kernels):
            user_graph = self.graph_kernels[self.metric]
            graph_fn = _resolve_force(
                user_graph, graph_fn,
                tuple(sorted(self.kernel_params.keys())) if self.kernel_params else ("a", "b"),
                "Q"
            )

        # NOTE(review): the returned code is unused here; the call is kept in
        # case ``_metric_to_code`` rejects unknown metrics early — confirm and
        # drop it if it has no such effect.
        _metric_to_code(self.metric)

        # ``fit`` accepts random_state=None; draw a seed from NumPy's global
        # RNG then instead of failing on int(None).
        if self.random_state is None:
            seed = int(np.random.randint(0, np.iinfo(np.int32).max))
        else:
            seed = int(self.random_state)

        # Base kNN graph; k = n_neighbors + 1 (presumably each point counts
        # itself as a neighbour — TODO confirm in build_weighted_graph).
        _rho, _sigmas, ei, ej, P, _ = build_weighted_graph(
            X, k=self.n_neighbors + 1,
            symmetrize=self.symmetrize,
            backend=self.graph_backend,
            kernel_function=graph_fn,
            kernel_params=params,
            seed=seed,
            metric=self.metric,
            metric_params=self.metric_params,
        )
        ei = np.ascontiguousarray(ei, dtype=np.int32)
        ej = np.ascontiguousarray(ej, dtype=np.int32)
        P = np.ascontiguousarray(P, dtype=np.float32)

        n = X.shape[0]

        # Spectral initialisation on the graph, truncated to n_components.
        if self.verbose:
            print("Initializing embeddings…")
        Y = init_spectral(
            X, P, ei, ej, n=n,
            init_mode=self.init_mode,
            component_strategy=self.component_strategy,
            backend=self.spectral_backend,
            tol=self.spectral_tol
        ).astype(np.float32, copy=False)
        Y = np.ascontiguousarray(Y[:, : self.n_components])

        # Normalise edge weights to a maximum of 1.
        weights = P / (float(P.max()) if P.size > 0 else 1.0)

        # Expected repulsion calibration: row_scale * nsr * row_sum is
        # ~alpha * (n-1)/n for every row (checked by the verbose print).
        row_sum = np.bincount(ei, weights=weights, minlength=n).astype(np.float32)
        nsr = float(self.negative_sample_rate)
        alpha = float(self.repulsion)
        row_scale = ((n - 1) / n) / (nsr * (row_sum + 1e-12))
        row_scale *= alpha
        row_scale.flags.writeable = False

        if self.verbose:
            target = float(np.mean(row_scale * nsr * row_sum))
            print("target ≈ α*(n-1)/n:", target)

        # Optional (a,b) fit from min_dist/spread
        params_opt = params.copy()
        if self.min_dist is not None:
            a_fit, b_fit = _fit_ab_from_min_dist_heavytail(self.spread, float(self.min_dist), self.embed_metric)
            params_opt[0] = a_fit
            params_opt[1] = b_fit
            if self.verbose:
                print(f"Fitted (a,b) ≈ ({float(a_fit):.4g}, {float(b_fit):.4g}) from min_dist={self.min_dist}, spread={self.spread}")

        return ei, ej, weights, Y, params_opt, f_attr, f_rep, row_scale

    def get_trace(self):
        """Return the recorded trajectory snapshots.

        Returns
        -------
        tuple
            ``(trace, epochs)``: ``trace`` is the snapshot buffer filled by
            the last ``transform`` call (snapshots x traced points x dims)
            and ``epochs`` the epoch index of each snapshot.  Returns
            ``(None, None)`` when tracing was disabled, produced no
            snapshots, or ``transform`` has not run yet.
        """
        # getattr: these attributes only exist after ``transform`` has run
        # (and trace_epochs_ only when tracing was enabled).
        trace = getattr(self, "trace_", None)
        if trace is None:
            return None, None
        return trace, getattr(self, "trace_epochs_", None)

    def get_params(self, deep=True):
        # Shit may be missing... idk, fix later! 
        keys = [
            "n_neighbors","n_components","attr_kernel","rep_kernel","graph_kernel",
            "kernel_params","optimizer","negative_sample_rate","gradient_clip",
            "random_state","verbose","repulsion",
            # graph/init
            "graph_backend","symmetrize",
            "init_mode","component_strategy","spectral_backend","spectral_tol",
            # engine/runtime
            "block_size","num_threads","deterministic",
            # tracing
            "trace_indices","trace_every",
            "embed_metric","embed_p","embed_eps",
            "min_dist","spread",
        ]
        return {k: getattr(self, k) for k in keys}

if __name__ == "__main__":
    # Demo: embed a toy dataset with CTMC, plot the result and the traced
    # trajectories.  Requires matplotlib and scikit-learn.
    import time

    import matplotlib.pyplot as plt
    from sklearn.datasets import (
        load_breast_cancer,
        load_digits,
        load_wine,
        make_moons,
        make_swiss_roll,
    )
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import MinMaxScaler, StandardScaler

    # 1) Load and preprocess data.  Alternatives:
    # X, y = load_digits(return_X_y=True)
    # X, y = load_breast_cancer(return_X_y=True)
    # X, y = load_wine(return_X_y=True)
    # X, y = make_swiss_roll(10000, hole=False)
    X, y = make_moons(8000, noise=0.08, random_state=42)

    # X = MinMaxScaler().fit_transform(X)  # alternative: scale to [0, 1]
    X = StandardScaler().fit_transform(X)  # zero mean, unit variance per feature
    X = X.astype(np.float32)

    # Optional: PCA to 50D for speed on high-dimensional inputs, e.g.
    # X_50 = PCA(n_components=50, svd_solver="randomized", random_state=0) \
    #     .fit_transform(X).astype(np.float32, copy=False)
    X_50 = X

    # 2) Configure CTMC
    # - optimizer: "adam" tends to converge quickly
    # - negative_sample_rate (nsr): a bit higher improves cluster separation
    # - repulsion:
    #     * smooth embedding:   repulsion ≈ nsr
    #     * stronger clustering repulsion ≈ nsr^2
    rng = np.random.default_rng(42)
    trace_idx = rng.choice(X_50.shape[0], size=min(400, X_50.shape[0]), replace=False).astype(np.int32)

    model = CTMCEmbedding(
        n_neighbors=15,
        n_components=2,
        optimizer="sgd",
        negative_sample_rate=25.0,
        gradient_clip=0.0,         # set None or 0 to disable; kernel will specialize
        random_state=42,
        verbose=True,
        repulsion=25.0**2,         # try 225.0 (15^2) if you want tighter clusters
        # graph/init
        graph_backend="pynndescent",
        init_mode="standard",
        spectral_backend="auto",
        spectral_tol=1e-6,
        # engine/runtime
        block_size=256,
        num_threads=None,          # set to an int to pin threads; or deterministic=True for 1 thread
        deterministic=False,
        # tracing (optional)
        trace_indices=trace_idx,   # set to None to disable
        trace_every=50,            # snapshot every 50 epochs; set 0 to disable
        embed_metric="l2sqr",
        min_dist=.5,
    )
    # For comparison (requires umap-learn):
    #     from umap import UMAP; model = UMAP(n_neighbors=15)

    # 3) Fit + transform
    n_epochs = 1000
    lr = .1

    t0 = time.perf_counter()
    Y = model.fit_transform(X_50, n_epochs=n_epochs, learning_rate=lr)
    t1 = time.perf_counter()
    print(f"[CTMC] Y.shape={Y.shape}, time={t1 - t0:.2f}s")

    # 4) Plot embedding
    fig, ax = plt.subplots(1, 1, figsize=(6, 5), constrained_layout=True)
    ax.scatter(Y[:, 0], Y[:, 1], c=y, s=8, cmap="Spectral", edgecolors="none")
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"CTMC embedding (n={X_50.shape[0]})")
    plt.show()

    # 5) (Optional) Plot trace snapshots if tracing was enabled
    trace, epochs = model.get_trace()
    if trace is not None and epochs is not None and trace.size:
        n_snaps = trace.shape[0]
        cols = min(5, n_snaps)
        rows = int(np.ceil(n_snaps / cols))
        fig, axes = plt.subplots(rows, cols, figsize=(3.2 * cols, 3.2 * rows), squeeze=False, constrained_layout=True)
        for i in range(n_snaps):
            ax_i = axes[i // cols, i % cols]
            ax_i.scatter(trace[i, :, 0], trace[i, :, 1], s=6, c="k", alpha=0.6)
            ax_i.set_title(f"epoch {int(epochs[i])}")
            ax_i.set_xticks([])
            ax_i.set_yticks([])
        # Hide any empty subplots
        for j in range(n_snaps, rows * cols):
            axes[j // cols, j % cols].axis("off")
        plt.show()